package org.jvm.nativemethod.methods.sun.misc;

import org.jvm.instruction.base.InstructionUtil;
import org.jvm.nativemethod.NativeMethodRegister;
import org.jvm.rtda.Object;
import org.jvm.rtda.*;
import org.jvm.rtda.heap.Klass;
import org.jvm.rtda.heap.classLoader.KlassLoaderRegister;
import org.jvm.rtda.heap.classmember.Field;
import org.jvm.rtda.thread.*;
import org.jvm.util.JvmUtil;

/**
 * @author 海燕
 * @date 2023/2/26
 */
public class UnsafeNativeMethod extends NativeMethodRegister {

    /*
     * Storage model used by these natives: an object's storage (Object.getData()) is either a
     * Slots instance (ordinary objects; presumably also the static-variable holder built by
     * staticFieldBase) or a plain Java array (array objects). The "offset" that Unsafe passes
     * around is therefore a slot id for objects and an element index for arrays. That is why
     * arrayBaseOffset returns 0, arrayIndexScale returns 1, and objectFieldOffset returns the
     * Field's "slot" value.
     *
     * Local-variable layout for these instance natives: slot 0 is the Unsafe receiver, arguments
     * start at slot 1, and each long/double argument occupies two slots.
     */

    /** Internal name of the class whose native methods are registered here. */
    private static final String UNSAFE_CLASS_NAME = "sun/misc/Unsafe";

    /**
     * Registers every supported {@code sun.misc.Unsafe} native with its implementation in this
     * class.
     *
     * @throws NoSuchMethodException if an implementing method is missing from this class
     */
    public static void init() throws NoSuchMethodException {
        // Off-heap memory
        registerUnsafe("allocateMemory", "(J)J", "allocateMemory");
        registerUnsafe("reallocateMemory", "(JJ)J", "reallocateMemory");
        registerUnsafe("freeMemory", "(J)V", "freeMemory");
        registerUnsafe("addressSize", "()I", "addressSize");
        registerUnsafe("getByte", "(J)B", "mem_getByte");
        registerUnsafe("putLong", "(JJ)V", "mem_putLong");

        // Array layout
        registerUnsafe("arrayBaseOffset", "(Ljava/lang/Class;)I", "arrayBaseOffset");
        registerUnsafe("arrayIndexScale", "(Ljava/lang/Class;)I", "arrayIndexScale");

        // Field offsets / bases
        registerUnsafe("objectFieldOffset", "(Ljava/lang/reflect/Field;)J", "objectFieldOffset");
        registerUnsafe("staticFieldOffset", "(Ljava/lang/reflect/Field;)J", "staticFieldOffset");
        registerUnsafe("staticFieldBase", "(Ljava/lang/reflect/Field;)Ljava/lang/Object;", "staticFieldBase");

        // Compare-and-swap
        registerUnsafe("compareAndSwapObject", "(Ljava/lang/Object;JLjava/lang/Object;Ljava/lang/Object;)Z", "compareAndSwapObject");
        registerUnsafe("compareAndSwapInt", "(Ljava/lang/Object;JII)Z", "compareAndSwapInt");
        registerUnsafe("compareAndSwapLong", "(Ljava/lang/Object;JJJ)Z", "compareAndSwapLong");

        // Classes
        registerUnsafe("allocateInstance", "(Ljava/lang/Class;)Ljava/lang/Object;", "allocateInstance");
        registerUnsafe("defineClass", "(Ljava/lang/String;[BIILjava/lang/ClassLoader;Ljava/security/ProtectionDomain;)Ljava/lang/Class;", "defineClass");
        registerUnsafe("shouldBeInitialized", "(Ljava/lang/Class;)Z", "shouldBeInitialized");
        registerUnsafe("ensureClassInitialized", "(Ljava/lang/Class;)V", "ensureClassInitialized");

        // Volatile / ordered field and array-element accessors
        registerUnsafe("getObjectVolatile", "(Ljava/lang/Object;J)Ljava/lang/Object;", "getObject");
        registerUnsafe("putObjectVolatile", "(Ljava/lang/Object;JLjava/lang/Object;)V", "putObject");
        registerUnsafe("getOrderedObject", "(Ljava/lang/Object;J)Ljava/lang/Object;", "getObject");
        registerUnsafe("putOrderedObject", "(Ljava/lang/Object;JLjava/lang/Object;)V", "putObject");
        registerUnsafe("getIntVolatile", "(Ljava/lang/Object;J)I", "getInt");
        registerUnsafe("getLongVolatile", "(Ljava/lang/Object;J)J", "getLong");
        registerUnsafe("getBooleanVolatile", "(Ljava/lang/Object;J)Z", "getBoolean");
        registerUnsafe("getByteVolatile", "(Ljava/lang/Object;J)B", "getByte");
        registerUnsafe("getCharVolatile", "(Ljava/lang/Object;J)C", "getChar");
        registerUnsafe("getShortVolatile", "(Ljava/lang/Object;J)S", "getShort");
        registerUnsafe("getFloatVolatile", "(Ljava/lang/Object;J)F", "getFloat");
        registerUnsafe("getDoubleVolatile", "(Ljava/lang/Object;J)D", "getDouble");
    }

    /**
     * Registers one {@code sun.misc.Unsafe} native, bound to the static {@code implName(Frame)}
     * method declared in this class.
     *
     * @param name       method name in sun.misc.Unsafe
     * @param descriptor JVM method descriptor of that native
     * @param implName   name of the implementing method in this class
     * @throws NoSuchMethodException if this class declares no {@code implName(Frame)}
     */
    private static void registerUnsafe(String name, String descriptor, String implName)
            throws NoSuchMethodException {
        register(UNSAFE_CLASS_NAME, name, descriptor,
                UnsafeNativeMethod.class.getDeclaredMethod(implName, Frame.class));
    }

    /** Returns the storage of the target object in local 1: Slots for objects, a Java array for arrays. */
    private static java.lang.Object targetData(LocalVars vars) {
        return vars.getRef(1).getData();
    }

    /** Returns the long offset held in locals 2-3, narrowed to a slot id / array index. */
    private static int targetOffset(LocalVars vars) {
        return (int) vars.getLong(2);
    }

    /**
     * Stores a reference into a field slot or reference-array element
     * ({@code putObjectVolatile}, {@code putOrderedObject}). Locals: 1 = target, 2-3 = offset,
     * 4 = new value.
     */
    public static void putObject(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        java.lang.Object data = targetData(vars);
        int offset = targetOffset(vars);
        Object value = vars.getRef(4);

        if (data instanceof Slots) {
            ((Slots) data).setRef(offset, value);
        } else {
            ((Object[]) data)[offset] = value;
        }
    }

    /** Reads a {@code short} field or {@code short[]} element ({@code getShortVolatile}). */
    public static void getShort(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        java.lang.Object data = targetData(vars);
        int offset = targetOffset(vars);
        OperandStack stack = frame.getOperandStack();

        if (data instanceof Slots) {
            stack.pushInt(((Slots) data).getInt(offset));
        } else {
            // short is signed: plain widening sign-extends (masking with 0xFFFF would turn -1 into 65535)
            stack.pushInt(((short[]) data)[offset]);
        }
    }

    /** Reads a {@code float} field or {@code float[]} element ({@code getFloatVolatile}). */
    public static void getFloat(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        java.lang.Object data = targetData(vars);
        int offset = targetOffset(vars);
        OperandStack stack = frame.getOperandStack();

        if (data instanceof Slots) {
            stack.pushFloat(((Slots) data).getFloat(offset));
        } else {
            stack.pushFloat(((float[]) data)[offset]);
        }
    }

    /** Reads a {@code double} field or {@code double[]} element ({@code getDoubleVolatile}). */
    public static void getDouble(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        java.lang.Object data = targetData(vars);
        int offset = targetOffset(vars);
        OperandStack stack = frame.getOperandStack();

        if (data instanceof Slots) {
            stack.pushDouble(((Slots) data).getDouble(offset));
        } else {
            stack.pushDouble(((double[]) data)[offset]);
        }
    }

    /** Reads a {@code long} field or {@code long[]} element ({@code getLongVolatile}). */
    public static void getLong(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        java.lang.Object data = targetData(vars);
        int offset = targetOffset(vars);
        OperandStack stack = frame.getOperandStack();

        if (data instanceof Slots) {
            stack.pushLong(((Slots) data).getLong(offset));
        } else {
            stack.pushLong(((long[]) data)[offset]);
        }
    }

    /**
     * Reads a {@code boolean} field or array element ({@code getBooleanVolatile}). Boolean array
     * storage is read as a {@code byte[]}, matching the cast below.
     */
    public static void getBoolean(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        java.lang.Object data = targetData(vars);
        int offset = targetOffset(vars);
        OperandStack stack = frame.getOperandStack();

        if (data instanceof Slots) {
            stack.pushBoolean(((Slots) data).getInt(offset) == 1);
        } else {
            stack.pushBoolean(((byte[]) data)[offset] == 1);
        }
    }

    /** Reads a {@code byte} field or {@code byte[]} element ({@code getByteVolatile}). */
    public static void getByte(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        java.lang.Object data = targetData(vars);
        int offset = targetOffset(vars);
        OperandStack stack = frame.getOperandStack();

        if (data instanceof Slots) {
            stack.pushInt(((Slots) data).getInt(offset));
        } else {
            // byte is signed: plain widening sign-extends (masking with 0xFF would turn -1 into 255)
            stack.pushInt(((byte[]) data)[offset]);
        }
    }

    /** Reads a {@code char} field or {@code char[]} element ({@code getCharVolatile}). */
    public static void getChar(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        java.lang.Object data = targetData(vars);
        int offset = targetOffset(vars);
        OperandStack stack = frame.getOperandStack();

        if (data instanceof Slots) {
            stack.pushInt(((Slots) data).getInt(offset));
        } else {
            // char is unsigned, so widening to int already zero-extends
            stack.pushInt(((char[]) data)[offset]);
        }
    }

    /**
     * Allocates an instance of the class in local 1 via {@code Klass.newObject()}. No constructor
     * is invoked here.
     */
    public static void allocateInstance(Frame frame) {
        Object classObj = frame.getLocalVars().getRef(1);
        Klass klass = (Klass) classObj.getExtra();
        frame.getOperandStack().pushRef(klass.newObject());
    }

    /**
     * Defines a class named by local 1 from {@code b[off, off + len)} (locals 2, 3, 4) and pushes
     * its Class object. The ClassLoader and ProtectionDomain arguments (locals 5 and 6) are
     * ignored: the class is always defined through the application class loader.
     */
    public static void defineClass(Frame frame) {
        LocalVars localVars = frame.getLocalVars();
        Object nameObj = localVars.getRef(1);
        Object byteArr = localVars.getRef(2);
        int off = localVars.getInt(3);
        int len = localVars.getInt(4);

        // Binary name ("a.b.C") -> internal name ("a/b/C")
        String klassName = JvmUtil.javaStringToJvmString(nameObj).replace(".", "/");
        byte[] klassData = new byte[len];
        System.arraycopy((byte[]) byteArr.getData(), off, klassData, 0, len);

        Klass klass = KlassLoaderRegister.getAppKlassLoader().defineKlass(klassName, klassData, frame.getThread());
        frame.getOperandStack().pushRef(klass.getjClass());
    }

    /** Pushes {@code true} while initialization of the class in local 1 has not yet started. */
    public static void shouldBeInitialized(Frame frame) {
        Object classObj = frame.getLocalVars().getRef(1);
        Klass klass = (Klass) classObj.getExtra();
        frame.getOperandStack().pushBoolean(!klass.isInitStarted());
    }

    /**
     * Ensures the class in local 1 is initialized. Initialization is only kicked off when it has
     * not started yet, so repeated calls do not schedule the class initializer again.
     */
    public static void ensureClassInitialized(Frame frame) {
        Object classObj = frame.getLocalVars().getRef(1);
        Klass klass = (Klass) classObj.getExtra();
        if (!klass.isInitStarted()) {
            InstructionUtil.initClass(frame.getThread(), klass);
        }
    }

    /** Pushes the static field's "slot" value as its offset into the static-variable storage. */
    public static void staticFieldOffset(Frame frame) {
        Object jField = frame.getLocalVars().getRef(1);
        int offset = jField.getIntVar("slot");
        frame.getOperandStack().pushLong(offset);
    }

    /**
     * Pushes a synthetic object whose storage is the declaring class's static variables, so the
     * generic field accessors in this class can read and write static fields through it.
     */
    public static void staticFieldBase(Frame frame) {
        Object fieldObj = frame.getLocalVars().getRef(1);
        // A java.lang.reflect.Field without a native Field attached is a copy; the JDK keeps the
        // original in its "root" field, which does carry the native Field.
        Field field = fieldObj.getExtra() != null
                ? (Field) fieldObj.getExtra()
                : (Field) fieldObj.getRefVar("root", "Ljava/lang/reflect/Field;").getExtra();

        Object base = new Object();
        base.setData(field.getKlass().getStaticVars());
        frame.getOperandStack().pushRef(base);
    }

    /** Array offsets are plain element indices in this JVM, so the first element is at offset 0. */
    public static void arrayBaseOffset(Frame frame) {
        frame.getOperandStack().pushInt(0);
    }

    /** Array offsets are plain element indices in this JVM, so consecutive elements differ by 1. */
    public static void arrayIndexScale(Frame frame) {
        frame.getOperandStack().pushInt(1);
    }

    /** Reports an 8-byte native address size. */
    public static void addressSize(Frame frame) {
        frame.getOperandStack().pushInt(8);
    }

    /** Pushes the instance field's "slot" value, which the accessors use as a Slots index. */
    public static void objectFieldOffset(Frame frame) {
        Object jField = frame.getLocalVars().getRef(1);
        int offset = jField.getIntVar("slot");
        frame.getOperandStack().pushLong(offset);
    }

    /**
     * Compare-and-swap on a reference.
     * <p>
     * How are the compare and the set kept atomic, so that no other thread can change the
     * reference in between? On a physical CPU, CAS is a single instruction that other threads
     * cannot interrupt. In this JVM, a native method executes as a single instruction, and only
     * one Java thread runs, so the execution cannot be interrupted.
     *
     * @param frame locals: 1 = target object or array, 2-3 = offset, 4 = expected reference,
     *              5 = new reference; pushes whether the swap happened
     */
    public static void compareAndSwapObject(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        // Slots for a non-array instance, the underlying Java array for an array instance
        java.lang.Object data = targetData(vars);
        // Slot id for a non-array instance, element index for an array instance
        int offset = targetOffset(vars);
        Object expected = vars.getRef(4);
        Object newVal = vars.getRef(5);

        boolean swapped;
        if (data instanceof Slots) {
            Slots slots = (Slots) data;
            // Reference identity: swap only if the slot still holds the expected reference
            swapped = slots.getRef(offset) == expected;
            if (swapped) {
                slots.setRef(offset, newVal);
            }
        } else {
            Object[] refs = (Object[]) data;
            swapped = refs[offset] == expected;
            if (swapped) {
                refs[offset] = newVal;
            }
        }
        frame.getOperandStack().pushBoolean(swapped);
    }

    /** Reads an {@code int} field or {@code int[]} element ({@code getIntVolatile}). */
    public static void getInt(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        java.lang.Object data = targetData(vars);
        int offset = targetOffset(vars);
        OperandStack stack = frame.getOperandStack();

        if (data instanceof Slots) {
            stack.pushInt(((Slots) data).getInt(offset));
        } else {
            stack.pushInt(((int[]) data)[offset]);
        }
    }

    /** Reads a reference field or reference-array element ({@code getObjectVolatile}, {@code getOrderedObject}). */
    public static void getObject(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        java.lang.Object data = targetData(vars);
        int offset = targetOffset(vars);
        OperandStack stack = frame.getOperandStack();

        if (data instanceof Slots) {
            stack.pushRef(((Slots) data).getRef(offset));
        } else {
            stack.pushRef(((Object[]) data)[offset]);
        }
    }

    /**
     * Compare-and-swap on an {@code int} field or {@code int[]} element.
     * Locals: 1 = target, 2-3 = offset, 4 = expected, 5 = new value; pushes whether it swapped.
     */
    public static void compareAndSwapInt(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        java.lang.Object data = targetData(vars);
        int offset = targetOffset(vars);
        int expected = vars.getInt(4);
        int newVal = vars.getInt(5);

        boolean swapped;
        if (data instanceof Slots) {
            Slots slots = (Slots) data;
            swapped = slots.getInt(offset) == expected;
            if (swapped) {
                slots.setInt(offset, newVal);
            }
        } else {
            int[] values = (int[]) data;
            swapped = values[offset] == expected;
            if (swapped) {
                values[offset] = newVal;
            }
        }
        frame.getOperandStack().pushBoolean(swapped);
    }

    /**
     * Compare-and-swap on a {@code long} field or {@code long[]} element.
     * Locals: 1 = target, 2-3 = offset, 4-5 = expected, 6-7 = new value; pushes whether it swapped.
     */
    public static void compareAndSwapLong(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        java.lang.Object data = targetData(vars);
        int offset = targetOffset(vars);
        long expected = vars.getLong(4);
        long newVal = vars.getLong(6);

        boolean swapped;
        if (data instanceof Slots) {
            Slots slots = (Slots) data;
            swapped = slots.getLong(offset) == expected;
            if (swapped) {
                slots.setLong(offset, newVal);
            }
        } else {
            long[] values = (long[]) data;
            swapped = values[offset] == expected;
            if (swapped) {
                values[offset] = newVal;
            }
        }
        frame.getOperandStack().pushBoolean(swapped);
    }


    /** Allocates {@code bytes} (locals 1-2) of off-heap memory via Malloc and pushes its address. */
    public static void allocateMemory(Frame frame) {
        long bytes = frame.getLocalVars().getLong(1);
        frame.getOperandStack().pushLong(Malloc.allocate(bytes));
    }

    /** Resizes the block at {@code address} (locals 1-2) to {@code bytes} (locals 3-4); pushes the new address. */
    public static void reallocateMemory(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        long address = vars.getLong(1);
        long bytes = vars.getLong(3);
        frame.getOperandStack().pushLong(Malloc.reallocate(address, bytes));
    }

    /** Releases the off-heap block at {@code address} (locals 1-2). */
    public static void freeMemory(Frame frame) {
        long address = frame.getLocalVars().getLong(1);
        Malloc.free(address);
    }

    /** Implements {@code Unsafe.getByte(long)}: reads one byte of off-heap memory at {@code address} (locals 1-2). */
    public static void mem_getByte(Frame frame) {
        long address = frame.getLocalVars().getLong(1);
        Malloc.MemIndex memIndex = Malloc.memoryAt(address);
        // byte is signed: the (byte) cast makes the int widening sign-extend
        byte value = (byte) memIndex.getMem()[(int) memIndex.getStart()];
        frame.getOperandStack().pushInt(value);
    }

    /** Implements {@code Unsafe.putLong(long, long)}: writes {@code value} (locals 3-4) at {@code address} (locals 1-2). */
    public static void mem_putLong(Frame frame) {
        LocalVars vars = frame.getLocalVars();
        long address = vars.getLong(1);
        long value = vars.getLong(3);
        Malloc.putLong(address, value);
    }

}
